package com.snail.common.datascope.handler;

import com.snail.common.datascope.annotation.DataScope;
import com.snail.common.datascope.utils.DataScopeUtils;
import com.snail.common.security.utils.SecurityUtils;
import lombok.SneakyThrows;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.HexValue;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.FromItem;
import net.sf.jsqlparser.statement.select.PlainSelect;

import java.lang.reflect.Method;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.Optional;

/**
 * @Description: Data-scope (row-level) permission handler. Appends the condition built by
 *               DataScopeUtils to a SELECT's WHERE clause when the executing mapper method
 *               is covered by @DataScope.
 * @Author: Snail
 * @CreateDate: 2023/8/10 17:02
 * @Version: V1.0
 */
public class DataScopePermissionHandler {

    /**
     * Returns the WHERE expression to install on {@code plainSelect}, with the data-scope
     * condition ANDed in when the currently executing mapper method requires it.
     *
     * <p>The condition is added only when all of the following hold:
     * <ul>
     *   <li>the current user is not an admin;</li>
     *   <li>the mapper class named by the statement id carries {@link DataScope};</li>
     *   <li>a public method with the statement's method name exists on that class;</li>
     *   <li>either the class annotation lists the method in {@code includeMethod()}, or the
     *       method itself is annotated with {@link DataScope}.</li>
     * </ul>
     * In every other case the original WHERE expression is returned unchanged.
     *
     * @param plainSelect  the SELECT being rewritten
     * @param whereSegment the mapped-statement id, expected as {@code <mapper FQCN>.<method name>}
     * @return the WHERE expression to use; may be {@code null} when the statement had no WHERE
     *     and no condition was added
     */
    @SneakyThrows
    public Expression getSqlSegment(PlainSelect plainSelect, String whereSegment) {
        Expression where = plainSelect.getWhere();
        // Admin users get no extra conditions.
        if (SecurityUtils.isAdminUser()) {
            return where;
        }
        // A statement id without a namespace cannot name an annotated mapper method.
        int lastDot = whereSegment == null ? -1 : whereSegment.lastIndexOf('.');
        if (lastDot < 0) {
            return where;
        }
        // 1. Resolve the mapper class from the statement id's namespace.
        Class<?> mapperClass = Class.forName(whereSegment.substring(0, lastDot));
        // Only mappers annotated at class level are considered.
        // NOTE(review): this means a method-level @DataScope on an un-annotated mapper is
        // never applied — confirm that is intended.
        DataScope dataScope = mapperClass.getAnnotation(DataScope.class);
        if (dataScope == null) {
            return where;
        }
        // 2. Locate the executing mapper method by name.
        // NOTE(review): with overloads, whichever same-named public method getMethods()
        // yields first is used; getMethods() does not guarantee an order.
        String methodName = whereSegment.substring(lastDot + 1);
        Optional<Method> first = Arrays.stream(mapperClass.getMethods())
                .filter(item -> item.getName().equals(methodName))
                .findFirst();
        // No matching method name: leave the query untouched.
        if (!first.isPresent()) {
            return where;
        }
        // 3. The class-level annotation applies to the methods it explicitly includes.
        if (Arrays.asList(dataScope.includeMethod()).contains(methodName)) {
            return buildExpression(where, resolveAlias(plainSelect), dataScope);
        }
        // 4. Otherwise, a @DataScope on the method itself decides.
        DataScope methodScope = first.get().getAnnotation(DataScope.class);
        if (methodScope != null) {
            return buildExpression(where, resolveAlias(plainSelect), methodScope);
        }
        return where;
    }

    /**
     * Resolves the qualifier for the data-scope columns: the FROM item's alias when present,
     * otherwise the table name. Only called once a condition is actually going to be added,
     * so unscoped queries with unusual FROM items are never affected.
     *
     * @param plainSelect the SELECT being rewritten
     * @return the alias or table name to qualify columns with
     * @throws IllegalStateException when there is no FROM item, or it is an un-aliased non-table
     */
    private static String resolveAlias(PlainSelect plainSelect) {
        FromItem fromItem = plainSelect.getFromItem();
        if (fromItem == null) {
            throw new IllegalStateException("Cannot apply data scope: SELECT has no FROM item");
        }
        Alias alias = fromItem.getAlias();
        if (alias != null) {
            return alias.getName();
        }
        if (fromItem instanceof Table) {
            return ((Table) fromItem).getName();
        }
        // Fail closed rather than silently dropping the data-scope restriction.
        throw new IllegalStateException(
                "Cannot apply data scope: FROM item has no alias and is not a table: " + fromItem);
    }

    /**
     * Combines the existing WHERE condition with the data-scope condition.
     *
     * @param where     the existing WHERE expression; may be {@code null}
     * @param aliasName alias/table name used to qualify the data-scope columns
     * @param dataScope the data-scope annotation describing the restriction
     * @return {@code where} when DataScopeUtils produces no condition; the scope condition
     *     alone when there was no WHERE; otherwise {@code (where) AND (scope)}
     * @throws SQLException propagated from {@link DataScopeUtils#authExpression}
     */
    private Expression buildExpression(Expression where, String aliasName, DataScope dataScope) throws SQLException {
        Expression expression = DataScopeUtils.authExpression(dataScope, aliasName);
        if (expression == null) {
            return where;
        }
        if (where == null) {
            return expression;
        }
        // AND binds tighter than OR, and JSqlParser's AndExpression does not add parentheses.
        // Without them, "a OR b" + scope would render as "a OR b AND scope", and rows
        // matching "a" would bypass the data scope.
        return new AndExpression(new Parenthesis(where), new Parenthesis(expression));
    }
}
